Skip to content

MojoStorePagedSingleCache for single K/V paged store#372

Closed
NASA1473 wants to merge 2 commits into
dev/m13_ilufrom
single_kv_store
Closed

MojoStorePagedSingleCache for single K/V paged store#372
NASA1473 wants to merge 2 commits into
dev/m13_ilufrom
single_kv_store

Conversation

@NASA1473

Copy link
Copy Markdown
Collaborator

MojoStorePagedSingleCache, a single-tensor variant of MojoStorePagedKVCache that writes only one attribute (key OR value) into a paged cache. Supports both block_table and chunk_metadata paths, with an ixformer backend and accuracy tests.

@gemini-code-assist gemini-code-assist Bot left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces the MojoStorePagedSingleCache operator and its backend implementation IxformerStorePagedSingleCache, allowing the storage of a single attribute (key or value) into a single paged cache. It also includes comprehensive unit tests to verify accuracy and alignment with the full KV store. The review feedback suggests adding an explicit dimension check for the cache tensor to prevent potential indexing errors, as well as a device compatibility check between states and cache to avoid runtime failures.

Important

The consumer version of Gemini Code Assist on GitHub is being sunset. Starting June 18, 2026, new organization installations will be blocked, and all code review activity will officially cease on July 17, 2026.
For more details on the timeline and next steps, please review the Help Documentation.

Returns:
torch.Tensor: Updated ``cache`` after in-place writes.
"""
assert len(states.shape) == 3, "states must be (token_num, kv_head_num, head_dim), please check."

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The cache tensor is expected to have exactly 4 dimensions (total_phys_blocks, kv_heads, block_size, head_dim). Adding an explicit dimension check for cache (similar to the check for states) will prevent unexpected IndexError when accessing cache.shape[2] and provide a clearer error message.

Suggested change
assert len(states.shape) == 3, "states must be (token_num, kv_head_num, head_dim), please check."
assert len(states.shape) == 3, "states must be (token_num, kv_head_num, head_dim), please check."
assert len(cache.shape) == 4, "cache must be (total_phys_blocks, kv_heads, block_size, head_dim), please check."

Comment on lines +121 to +122
if cache.dtype != states.dtype:
raise ValueError("IxformerStorePagedSingleCache requires states and cache to have the same dtype.")

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

It is highly recommended to verify that states and cache are on the same device. A device mismatch between these tensors will cause runtime failures or silent errors during execution.

Suggested change
if cache.dtype != states.dtype:
raise ValueError("IxformerStorePagedSingleCache requires states and cache to have the same dtype.")
if cache.dtype != states.dtype:
raise ValueError("IxformerStorePagedSingleCache requires states and cache to have the same dtype.")
if cache.device != states.device:
raise ValueError("IxformerStorePagedSingleCache requires states and cache to be on the same device.")

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- The Ixformer single-cache backend passes the same tensor as both K and V to the underlying paged store, which will corrupt the cache for any mode that actually writes both.

Summary

Adds a new MojoStorePagedSingleCache operator (core + Ixformer backend) for storing only K or only V into a single paged cache, plus accuracy tests. Motivated by cases like SAGE prefill where K and V are stored independently.

Must fix

  • [BLOCKER] Ixformer backend writes states into both K and V cache slots -- mojo_opset/backends/ixformer/operators/kv_cache.py:130-138 -- paged_store_kv_cache_with_block_table(states, states, cache, cache, ...) passes the same tensor for both key and value arguments. Even with store_mode=1, this relies on the kernel ignoring the unused pair; if store_mode=1 does not mean "key-only" (or if it ever changes), this silently double-writes. Confirm the semantics of store_mode=1 and, if it stores only one side, document that explicitly; otherwise pass distinct dummy/None tensors per the kernel contract.

Suggestions

Suggestions (3)
  • [MAJOR] Reference impl is a Python loop over chunks -- mojo_opset/core/operators/kv_cache.py:227-233 -- chunk_metadata.tolist() + per-chunk indexed assignment will be very slow for large chunk counts; fine as a torch reference, but consider noting it is reference-only or vectorizing as MojoStorePagedKVCache likely does.
  • [MAJOR] Backend dtype check is stricter than core -- mojo_opset/backends/ixformer/operators/kv_cache.py:120-121 -- core operator does not require matching dtypes; backend raises. If the kernel truly requires it, fine; otherwise this will surface as backend-specific failures users can't predict from the core API.
  • [MINOR] Mixed assert/raise validation style -- mojo_opset/core/operators/kv_cache.py:206-221 -- core uses assert (stripped under -O) while the Ixformer subclass uses raise ValueError. Pick one; prefer explicit raises for input validation.

Nits

Nits (2)
  • [NIT] Empty __init__ that only calls super().__init__() is redundant -- mojo_opset/core/operators/kv_cache.py:174-177.
  • [NIT] Tests use type(a) is type(b) and raise NotImplementedError to skip; pytest.skip(...) would be clearer -- mojo_opset/tests/accuracy/operators/test_kv_cache.py:451,510.

Notes

  • [CHECK] Verify store_mode=1 semantics in ixf_f.paged_store_kv_cache_with_block_table actually means "store key only" (or "store single"); the diff gives no hint and this is the entire correctness argument for the backend.
  • [CHECK] MojoStorePagedSingleCache is registered for torch somewhere (._registry.get("torch") is used in tests) -- confirm the torch registration exists; not visible in this diff.

@github-actions

Copy link
Copy Markdown

Claude Code Review

Verdict: Request changes -- The ixformer single-cache backend ignores the chunk_metadata path inconsistently and may corrupt the cache by passing states as both K and V to the underlying kernel.

Summary

Adds a new MojoStorePagedSingleCache operator (core torch reference + ixformer backend) that writes a single K-or-V tensor into a paged cache, plus accuracy tests. Motivation appears to be supporting cases (e.g. SAGE prefill) where only one of K/V is stored.

Must fix

  • [BLOCKER] Ixformer backend passes states/cache twice with store_mode=1 -- mojo_opset/backends/ixformer/operators/kv_cache.py:131-138 -- The call passes states, states, cache, cache to paged_store_kv_cache_with_block_table. If store_mode=1 does not in fact restrict writes to a single side, the same tensor will be written twice (harmless for the single cache, but confirm semantics) and any aliasing assumption inside ixformer (distinct K/V buffers) may be violated. Verify store_mode=1 is the correct selector and that aliased K==V / kcache==vcache inputs are supported; otherwise allocate/forward properly.
  • [BLOCKER] Ixformer single-cache rejects chunk_metadata path entirely -- mojo_opset/backends/ixformer/operators/kv_cache.py:124-125 -- The core operator is documented to support chunk_metadata, and the tests in test_store_paged_single_cache exercise it. On ilu this will raise NotImplementedError, breaking that test unless bypass_not_implemented swallows it silently (which then hides backend coverage). Either implement the path or document/skip it explicitly.

Suggestions

Suggestions (3)
  • [MAJOR] Torch reference is a Python for loop over chunks -- mojo_opset/core/operators/kv_cache.py:230-235 -- Same pattern as the existing KV op, but worth confirming this is only used as a correctness reference; if any production path falls back to the torch impl it will be very slow for long sequences.
  • [MAJOR] Mismatch between core mixing rule and ixformer backend -- mojo_opset/core/operators/kv_cache.py:218-221 vs mojo_opset/backends/ixformer/operators/kv_cache.py:126-127 -- Core asserts chunk_metadata and block_table/cu_q_lens/context_kv_lens are mutually exclusive; ixformer requires block_table+context_kv_lens always. Align the contracts so callers do not see different validation by backend.
  • [MINOR] assert for input validation in core op -- mojo_opset/core/operators/kv_cache.py:213,218-221 -- Other new code uses raise ValueError; asserts are stripped under python -O. Use explicit exceptions for consistency.

Notes

  • [CHECK] ixformer-...20260623 wheel name has a future-dated build tag (year 2026) -- confirm this is intentional and not a typo from the previous 20260610 tag. -- .github/workflows/iluvatar_accuracy_ci.yml:73.
  • [CHECK] test_store_paged_single_cache_matches_full_kv_store only compares against the torch reference of the full-KV op, not the backend under test -- verify this is the intended coverage. -- mojo_opset/tests/accuracy/operators/test_kv_cache.py:548-567.

@NASA1473 NASA1473 closed this Jun 25, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants